classdef RigidICP < handle
    %RigidICP - Iterated closest point for rigid transformation
    %   Repeatedly matches transformed surface samples F*A(i) to closest
    %   points on a triangle mesh, rejects matches whose distance exceeds an
    %   outlier threshold, and re-estimates the rigid transformation F from
    %   the surviving matches.  See below for properties.
    %
    %   Typical use:
    %       ICP = RigidICP(Mesh,Tree,A,F);   % Tree, A and F are optional
    %       ICP.Solve();
    %       ICP.F                            % final transformation estimate
    %
    % 
    % Copyright (C) Russell H. Taylor 2013
    % For use with CIS I only
    % Do not redistribute without written permission from Russell Taylor
    
    properties
        Mesh                % Handle for surface mesh
        Tree                % Handle for covariance tree
        F = Frame()         % Current estimate of cartesian transformation
        A = vct3Array()     % Surface samples
        N = 0               % Number of surface samples
        FA = vct3Array()    % F*A
        Cp = vct3Array()    % Current closest point matches to A
        Ct = []             % Triangle indices corresponding to Cp
        Cd = []             % Distances from the F*A(i) to Cp(i)
        MeanErr             % mean(Cd)
        MedianErr           % median(Cd)
        RMSErr              % sqrt(Cd*Cd'/N)
        MaxErr              % norm(Cd,inf)
        FilterMask = []     % 0/1 vector saying which A's and Cp's are "good"
        AFiltered = vct3Array() % Subset of A with good matches
        CpFiltered = vct3Array()% Corresponding subset of Cp
        CdFiltered = []     % Corresponding subset of Cd
        CtFiltered = []     % Corresponding subset of Ct
        NFiltered = 0        % Number of matches
        FilteredMedianErr   % median(CdFiltered)
        FilteredMeanErr     % mean(CdFiltered)
        FilteredRMSErr      % sqrt(CdFiltered*CdFiltered'/N) (normalized by N, see Filter)
        FilteredMaxErr      % norm(CdFiltered,inf)
        Niter               % Number of iterations
        Fhistory = {};      % Fhistory{k} is the estimate of F after iteration k
        Stats = []          % accumulate per-iteration stats, one row per iteration:
                            % [...;[NFiltered,FilteredMedianErr,FilteredMeanErr,FilteredRMSErr,FilteredMaxErr,MedianErr,MeanErr,RMSErr,MaxErr];...]
        % Control Parameters below here
        OutlierFactor = 3   % Reject matches more than this times MedianErr
        OutlierMinThresh=0  % Insist that threshold be at least this much
        OutlierThresh = inf % Most recently used threshold value
        MinMatches = 3      % Must have this many matches
        MinMatchFactor = 0.8 % Also must have at least this percent of matches 
        MinIter = 5         % Go at least this many iterations unless error is 0
        MaxIter = 100       % Go at most this many iterations
        Tiny = 0.0000001    % Very small number
        MeanErrThresh = 0.01        % Stop if filtered mean error is less than this
        MeanErrChangeThresh = 0.01  % Stop if filtered mean error doesn't change by
                                    % this for MinIter-1 steps
        % Trace Flags
        TraceMatch          % If non-zero does some tracing of matches
    end
    
    methods
        function ICP = RigidICP(Mesh,Tree,A,F)
            %RigidICP Construct an ICP solver.
            %   Mesh - surface mesh handle
            %   Tree - (optional) covariance tree for Mesh; when omitted or
            %          passed as [], TriangleCovTree(Mesh) is built
            %   A    - (optional) surface samples
            %   F    - (optional) initial transformation estimate
            ICP.TraceMatch = 0;
            % Zero-argument construction leaves an unconfigured object
            % instead of failing on the undefined Mesh argument.
            if nargin < 1; return; end
            ICP.Mesh = Mesh;
            % Build the default tree when Tree was not supplied.  The nargin
            % test short-circuits so Tree is never read when it is undefined.
            if nargin < 2 || (isnumeric(Tree) && isempty(Tree))
                ICP.Tree = TriangleCovTree(ICP.Mesh);
            else
                ICP.Tree = Tree;
            end
            
            if nargin>2; ICP.A = A; end         
            if nargin>3; ICP.F = F; end 
            ICP.SetSamples();
        end
        
        function SetSamples(ICP,A,F)
            %SetSamples Install samples / initial frame and seed the matches.
            %   A - (optional) new surface samples
            %   F - (optional) new transformation estimate
            %   Every sample is initially matched against triangle 1 of the
            %   mesh; Match() refines these matches later.
            if nargin>1; ICP.A = A; end
            if nargin>2; ICP.F = F; end
            ICP.N = ICP.A.NumEl();
            if ICP.N==0; return; end
            ICP.FA = ICP.F*ICP.A;
            ICP.Ct = ones(1,ICP.N);
            ICP.Cd = zeros(1,ICP.N);
            [~,p,q,r] = ICP.Mesh.Triangle(1);
            for i =1:ICP.N
                [ICP.Cp(i),ICP.Cd(i)] = TriangleClosestPoint(ICP.FA(i),p,q,r);
            end
        end
        
        function Match(ICP)
            %Match Update Cp, Ct and Cd for the current FA and recompute the
            %   unfiltered error statistics (MedianErr, MeanErr, RMSErr, MaxErr).
            if ICP.TraceMatch
                fprintf('Matching  ');
            end
            for i = 1:ICP.N
                % Closest point on the previously matched triangle; it is
                % handed to the tree search as the initial candidate.
                Cti = ICP.Ct(i);
                [~,p,q,r] = ICP.Mesh.Triangle(Cti);
                [Cpi,Cdi] = TriangleClosestPoint(ICP.FA(i),p,q,r);
                [ICP.Cp(i),ICP.Ct(i),ICP.Cd(i)] = ICP.Tree.SearchForClosestPoint(ICP.FA(i),Cpi,Cti,Cdi);
                if bitand(ICP.TraceMatch,2)
                    % fprintf(' %d',ICP.Ct(i));
                    fprintf('.');
                end 
            end
            
            ICP.MedianErr = median(ICP.Cd);
            ICP.MeanErr = mean(ICP.Cd);
            ICP.RMSErr = sqrt((ICP.Cd*ICP.Cd')/ICP.N);
            ICP.MaxErr = norm(ICP.Cd,inf);
            
            if ICP.TraceMatch
                fprintf('\nErrors: Median=%f Mean=%f RMS=%f Max=%f\n', ...
                        ICP.MedianErr,ICP.MeanErr,ICP.RMSErr,ICP.MaxErr);
            end
        end
    
        function [Nm,Mmed,Me,Mrms,Mmx] = Filter(ICP,Thresh)
            %Filter Keep only the matches whose distance is <= Thresh.
            %   Thresh - (optional) distance threshold; defaults to the most
            %            recently used value (OutlierThresh) and is stored
            %            there when given
            %   Returns the number of surviving matches and their median,
            %   mean, RMS and max distance; the same values are stored in the
            %   NFiltered / Filtered*Err properties.
            if nargin < 2
                Thresh = ICP.OutlierThresh;
            else
                ICP.OutlierThresh = Thresh;
            end
            ICP.FilterMask = ICP.Cd<=Thresh;
            ICP.AFiltered  = vct3Array(ICP.A.el(:,ICP.FilterMask));
            ICP.CpFiltered = vct3Array(ICP.Cp.el(:,ICP.FilterMask));
            ICP.CdFiltered = ICP.Cd(ICP.FilterMask);
            ICP.CtFiltered = ICP.Ct(ICP.FilterMask);
            Nm = size(ICP.CdFiltered,2); ICP.NFiltered = Nm;
            Mmed = median(ICP.CdFiltered); ICP.FilteredMedianErr = Mmed;
            Me = mean(ICP.CdFiltered); ICP.FilteredMeanErr = Me;
            % NOTE(review): normalized by the total sample count N, not by Nm,
            % as documented on the FilteredRMSErr property - confirm intent.
            Mrms = sqrt((ICP.CdFiltered*ICP.CdFiltered')/ICP.N); ICP.FilteredRMSErr = Mrms;
            Mmx = norm(ICP.CdFiltered,inf); ICP.FilteredMaxErr = Mmx;
            if Nm<ICP.N || 1  % always print for now
                fprintf('  %d good matches out of %d\n',Nm,ICP.N);
            end
        end
        
        function [Nm,Mmed,Me,Mrms,Mmx] = Solve(ICP)
            %Solve Run the ICP iteration until convergence or MaxIter.
            %   Each iteration: Match, Filter with threshold
            %   max(OutlierFactor*MedianErr,OutlierMinThresh), re-estimate F
            %   from the filtered matches, and record a row in Stats.
            %   Stops when the filtered mean error is below Tiny, or (after
            %   MinIter iterations) below MeanErrThresh, or has changed by no
            %   more than MeanErrChangeThresh over the last MinIter-1 steps.
            %   Returns the filtered statistics of the last iteration.
            MeanCol = 3;        % column of FilteredMeanErr in CurStats()
            ICP.Stats = [];
            ICP.Fhistory = {};  % drop history left over from an earlier solve
            for iter = 1:ICP.MaxIter
                ICP.Niter = iter;
                fprintf('\nIteration %d:\n', ICP.Niter);
                ICP.Match();
                Thresh = max(ICP.OutlierFactor*ICP.MedianErr,ICP.OutlierMinThresh);
                [Nm,Mmed,Me,Mrms,Mmx] = ICP.Filter(Thresh);
                % NOTE(review): the second test fires when fewer than
                % (1-MinMatchFactor)*N matches survive, which does not agree
                % with the MinMatchFactor property comment - confirm intent.
                % This is only a warning; the iteration continues regardless.
                if Nm<ICP.MinMatches || Nm < ICP.N - ICP.MinMatchFactor*ICP.N
                    fprintf('Not enough matches for ICP: %d of %d\n',Nm,ICP.N);
                end
                ICP.F = FindBestRigidTransformation(ICP.AFiltered,ICP.CpFiltered);
                ICP.FA = ICP.F*ICP.A;
                ICP.Stats(ICP.Niter,:) = ICP.CurStats();
                ICP.Fhistory{ICP.Niter} = ICP.F;
                disp(ICP.CurStats());
                % disp(ICP.F);               
                if Me < ICP.Tiny; return; end
                if ICP.Niter > 1
                    % Compare like with like: mean against previous mean.
                    fprintf('Change = %f\n',Me-ICP.Stats(ICP.Niter-1,MeanCol));
                end
                if ICP.Niter >= ICP.MinIter
                    if Me < ICP.MeanErrThresh; return; end
                    % Quit only if every one of the last MinIter-1 steps
                    % changed the filtered mean error by at most the threshold.
                    CanQuit = 1;
                    for i=0:(ICP.MinIter-2)
                        Change = ICP.Stats(ICP.Niter-i,MeanCol)-ICP.Stats(ICP.Niter-i-1,MeanCol);
                        if abs(Change) > ICP.MeanErrChangeThresh
                            CanQuit = 0; 
                            break; 
                        end
                    end
                    if CanQuit; return; end
                end                    
            end        
        end
        
        function Test(ICP,N,Ds,DRD,Dp,Nout,DpOut)
            %Test Self-test: sample the mesh, displace the samples by a
            %   random frame, solve, and display the results.
            %   N     - number of samples requested from Mesh.Sample
            %   Ds    - samples come from Mesh.Sample(-Ds,Ds,N); also used as
            %           OutlierMinThresh
            %   DRD   - argument passed three times to RotMx.randD for the
            %           random rotation
            %   Dp    - random translation is vct3.rand(-Dp,Dp)
            %   Nout  - (optional) number of outlier samples to append
            %   DpOut - (optional) either a vct3BoundingBox to sample the
            %           outliers from, or a value used as
            %           Mesh.Sample(-DpOut,DpOut,Nout); defaults to 10*(Ds+1)
            A = ICP.Mesh.Sample(-Ds,Ds,N);
            if nargin > 5
                % Add outliers
                if nargin < 7; DpOut = 10*(Ds+1); end
                if isa(DpOut,'vct3BoundingBox')
                    A = [A,DpOut.Sample(Nout)];
                else
                    A = [A , ICP.Mesh.Sample(-DpOut,DpOut,Nout)];
                end
            end
            ICP.OutlierMinThresh = Ds;
            Fi = Frame(RotMx.randD(DRD,DRD,DRD),vct3.rand(-Dp,Dp));
            ICP.SetSamples(Fi*A,Frame());
            Tstart = clock;
            ICP.Solve();
            fprintf('Elapsed time = %f\n',etime(clock,Tstart));
            ICP.ShowStats();
            % The unterminated statements below intentionally display values.
            ICP.F
            Fi
            Fe = ICP.F*Fi
            [Ax,An]=Fe.R.AxisAngle();
            [Ax.el', An, An*180/pi]
        end
        
        function stats = CurStats(ICP)
            %CurStats Current statistics as a 1x9 row:
            %   [NFiltered, FilteredMedianErr, FilteredMeanErr, FilteredRMSErr,
            %    FilteredMaxErr, MedianErr, MeanErr, RMSErr, MaxErr]
            stats = [ICP.NFiltered, ...
                     ICP.FilteredMedianErr,  ICP.FilteredMeanErr,ICP.FilteredRMSErr,ICP.FilteredMaxErr, ...
                     ICP.MedianErr, ICP.MeanErr, ICP.RMSErr, ICP.MaxErr];
        end
        
        function ShowStats(ICP)
            %ShowStats Display the row returned by CurStats.
            disp(ICP.CurStats());
        end
    end
    
end

